import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as p
import seaborn as sns
from sfi import Data
from sfi import SFIToolkit as stata

#Import data from stata
def importData(cols):
    """Read the given Stata variables into a pandas DataFrame.

    cols: list of Stata variable names to fetch; also used as the column
          labels of the result.

    Returns a DataFrame with one column per requested variable. Stata
    missing values are requested as NaN (missingval=np.nan).

    NOTE: the values pass through np.array first, which stores everything
    in a single common dtype. When string and numeric variables are
    requested together, the numeric values therefore come back as text and
    callers have to cast them explicitly.
    """
    rawValues = Data.get(cols, missingval=np.nan)
    valueGrid = np.array(rawValues)
    return p.DataFrame(valueGrid, columns=cols)


#Columns to import
cols = ['state', 'year', 'safetyclass', 'nosafetyind', 'spec1Res', 'spec2Res', 'spec3Res', 'spec4Res']
data = importData(cols)

#Various model specifications
models = ['Spec1', 'Spec2', 'Spec3', 'Spec4']

#Establish treatment statuses
data['Treatment Status'] = 'Not Treated'
#Compare the indicator numerically rather than against the text '1'.
#importData routes values through np.array, so alongside string columns the
#indicator arrives as text, and a 1 held as a float is rendered '1.0', which
#a string comparison with '1' would silently miss.
#Unparseable or missing values become NaN and are left as 'Not Treated'.
isTreated = p.to_numeric(data['nosafetyind'], errors='coerce') == 1
data.loc[isTreated, 'Treatment Status'] = 'Treated'
#Applied last on purpose: this label overrides 'Treated'/'Not Treated' for
#every row whose safety class is anything other than 'Had Safety Inspections'
data.loc[data['safetyclass']!='Had Safety Inspections', 'Treatment Status'] = 'Never Treated'
#Ensure numeric year
data['year'] = data['year'].astype(float)

#Establish colors for plotting residuals (keys are the Treatment Status labels)
hueColors = {'Not Treated':'xkcd:red', 'Treated':'xkcd:green', 'Never Treated':'xkcd:gray'}
#Map raw model names to clean names used in plot titles
modelNames = {'Spec1':'Spec. 1', 'Spec2':'Spec. 2', 'Spec3':'Spec. 3',  'Spec4':'Spec. 4'}
#Dimensions for plots: width in inches, height derived from the golden ratio.
#These are the same for every model, so compute them once outside the loop.
x = 10
gr = (1+np.sqrt(5))/2
y = x/gr

#Draw and save one residuals-over-time plot per model specification
for model in models:
    #Residual column for this specification, e.g. 'Spec1' -> 'spec1Res'
    resCol = '{}Res'.format(model.lower())

    #Establish save file. os.path.join uses the separator of the running
    #platform, replacing a hand-written backslash literal that contained
    #invalid escape sequences and was also used on non-macOS Unix systems.
    saveFile = os.path.join('Plots', 'ResidualPlots', '{}ResPlot.png'.format(model))

    #Cast residuals to numerics
    data[resCol] = data[resCol].astype(float)

    #Build plot: one line per state (units='state', no aggregation),
    #colored by treatment status
    fig, ax = plt.subplots(1,1, figsize=(x,y))
    sns.lineplot(x='year', y=resCol, hue='Treatment Status', units='state', data=data, estimator=None, palette=hueColors, ax=ax)

    #Formatting and output, applied to this figure's axes explicitly
    ax.set_ylabel('Residual')
    ax.set_xlabel('Year')
    ax.set_title('{} Residuals vs. Time by Treatment Status'.format(modelNames[model]))
    fig.savefig(saveFile, bbox_inches='tight')

    #Close the figure once saved; otherwise every iteration leaves another
    #open figure registered with pyplot
    plt.close(fig)